{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NXTSugt6ieXh"
      },
      "source": [
        "## Training CBoW Model\n",
        "\n",
        "This notebook is part of the [AI for Beginners Curriculum](http://aka.ms/ai-beginners).\n",
        "\n",
        "In this example, we will train a CBoW language model to obtain our own Word2Vec embedding space. We will use the AG News dataset as the source of text."
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torchtext\n",
        "import os\n",
        "import collections\n",
        "import builtins\n",
        "import random\n",
        "import numpy as np"
      ],
      "metadata": {
        "id": "q-UiiJUKaxHj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ],
      "metadata": {
        "id": "TFbR8CZaTZ1q"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "First, let's load our dataset and define the tokenizer and vocabulary. We will set `vocab_size` to 5000 to limit computations a bit."
      ],
      "metadata": {
        "id": "HIwC7lI5T-ov"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def load_dataset(ngrams = 1, min_freq = 1, vocab_size = 5000, lines_cnt = 500):\n",
        "    tokenizer = torchtext.data.utils.get_tokenizer('basic_english')\n",
        "    print(\"Loading dataset...\")\n",
        "    # AG_NEWS returns the splits in the order (train, test)\n",
        "    train_dataset, test_dataset = torchtext.datasets.AG_NEWS(root='./data')\n",
        "    train_dataset = list(train_dataset)\n",
        "    test_dataset = list(test_dataset)\n",
        "    classes = ['World', 'Sports', 'Business', 'Sci/Tech']\n",
        "    print('Building vocab...')\n",
        "    counter = collections.Counter()\n",
        "    for i, (_, line) in enumerate(train_dataset):\n",
        "        counter.update(torchtext.data.utils.ngrams_iterator(tokenizer(line),ngrams=ngrams))\n",
        "        if i == lines_cnt:\n",
        "            break\n",
        "    vocab = torchtext.vocab.Vocab(collections.Counter(dict(counter.most_common(vocab_size))), min_freq=min_freq)\n",
        "    return train_dataset, test_dataset, classes, vocab, tokenizer"
      ],
      "metadata": {
        "id": "wdZuygtgiuLG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_dataset, test_dataset, _, vocab, tokenizer = load_dataset()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4d1nU1gsivGu",
        "outputId": "949fe272-ae0e-49f5-c373-6703458b3a74"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loading dataset...\n",
            "Building vocab...\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def encode(x, vocabulary, tokenizer = tokenizer):\n",
        "    return [vocabulary[s] for s in tokenizer(x)]"
      ],
      "metadata": {
        "id": "1XDYNhG8ToFV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LIlQk6_PaHVY"
      },
      "source": [
        "## CBoW Model\n",
        "\n",
        "CBoW learns to predict a word based on the $2N$ neighboring words. For example, when $N=1$, we get the following pairs from the sentence *I like to train networks*: (like, I), (I, like), (to, like), (like, to), (train, to), (to, train), (networks, train), (train, networks). Here, the first word is the neighboring word used as input, and the second word is the one we are predicting.\n",
        "\n",
        "To build a network that predicts a word, we supply a neighboring word as input and get the word number as output. The architecture of the CBoW network is the following:\n",
        "\n",
        "* The input word is passed through the embedding layer. This very embedding layer will be our Word2Vec embedding, so we define it separately as the `embedder` variable. We use an embedding size of 30 in this example, although you might want to experiment with higher dimensions (real word2vec has 300).\n",
        "* The embedding vector is then passed to a linear layer that predicts the output word. Thus it has `vocab_size` neurons.\n",
        "\n",
        "For the output, if we use `CrossEntropyLoss` as the loss function, we only need to provide word numbers as expected results, without one-hot encoding."
      ]
    },
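    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sketch of why plain word numbers are enough as targets (the symbols $z$, $y$ and $V$ are introduced here for illustration only): let $z \\in \\mathbb{R}^V$ be the output of the linear layer for one input word, with $V$ equal to `vocab_size`, and let $y$ be the number of the expected word. For a single pair, `CrossEntropyLoss` applies softmax to $z$ and takes the negative log-probability of entry $y$:\n",
        "\n",
        "$$L = -\\log \\frac{e^{z_y}}{\\sum_{k} e^{z_k}} = -z_y + \\log \\sum_{k} e^{z_k}$$\n",
        "\n",
        "The target enters only as the index $y$ that selects one output, which is why no one-hot vector is needed. With the default settings, the loss over a minibatch is the mean of these per-pair values."
      ]
    },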
    {
      "cell_type": "code",
      "source": [
        "vocab_size = len(vocab)\n",
        "\n",
        "embedder = torch.nn.Embedding(num_embeddings = vocab_size, embedding_dim = 30)\n",
        "model = torch.nn.Sequential(\n",
        "    embedder,\n",
        "    torch.nn.Linear(in_features = 30, out_features = vocab_size),\n",
        ")\n",
        "\n",
        "print(model)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "akKTcKQKkfl2",
        "outputId": "da687e3e-a8ec-4c1a-e456-ab8cd6ac7dad"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Sequential(\n",
            "  (0): Embedding(5002, 30)\n",
            "  (1): Linear(in_features=30, out_features=5002, bias=True)\n",
            ")\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nud6jgGPaHVa"
      },
      "source": [
        "## Preparing Training Data\n",
        "\n",
        "Now let's write the main function that computes CBoW word pairs from text. This function lets us specify the window size, and returns a list of pairs, each containing an input word and an output word. Note that it works on words as well as on vectors/tensors, which allows us to encode the text before passing it to the `to_cbow` function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x-dsXygOieXn",
        "outputId": "c2218280-e540-40ba-9546-efe48d0d714f"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[['like', 'I'], ['to', 'I'], ['I', 'like'], ['to', 'like'], ['train', 'like'], ['I', 'to'], ['like', 'to'], ['train', 'to'], ['networks', 'to'], ['like', 'train'], ['to', 'train'], ['networks', 'train'], ['to', 'networks'], ['train', 'networks']]\n",
            "[[232, 172], [5, 172], [172, 232], [5, 232], [0, 232], [172, 5], [232, 5], [0, 5], [1202, 5], [232, 0], [5, 0], [1202, 0], [5, 1202], [0, 1202]]\n"
          ]
        }
      ],
      "source": [
        "def to_cbow(sent,window_size=2):\n",
        "    res = []\n",
        "    for i,x in enumerate(sent):\n",
        "        for j in range(max(0,i-window_size),min(i+window_size+1,len(sent))):\n",
        "            if i!=j:\n",
        "                res.append([sent[j],x])\n",
        "    return res\n",
        "\n",
        "print(to_cbow(['I','like','to','train','networks']))\n",
        "print(to_cbow(encode('I like to train networks', vocab)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XVaaDLjaaHVb"
      },
      "source": [
        "Let's prepare the training dataset. We will go through the news items, call `to_cbow` to get the list of word pairs, and add those pairs to `X` and `Y`. To save time, we will only consider the first 10k news items. You can easily remove this limitation if you have more time to wait and want to get better embeddings :)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "54b-Gd9TieXo"
      },
      "outputs": [],
      "source": [
        "X = []\n",
        "Y = []\n",
        "for i, x in zip(range(10000), train_dataset):\n",
        "    for w1, w2 in to_cbow(encode(x[1], vocab), window_size = 5):\n",
        "        X.append(w1)\n",
        "        Y.append(w2)\n",
        "\n",
        "X = torch.tensor(X)\n",
        "Y = torch.tensor(Y)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "We will wrap the pairs in a simple dataset class, so that they can be fed to a dataloader:"
      ],
      "metadata": {
        "id": "cwWy0PzXWhN5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class SimpleIterableDataset(torch.utils.data.IterableDataset):\n",
        "    def __init__(self, X, Y):\n",
        "        super().__init__()\n",
        "        self.data = []\n",
        "        for i in range(len(X)):\n",
        "            self.data.append( (Y[i], X[i]) )\n",
        "        random.shuffle(self.data)\n",
        "\n",
        "    def __iter__(self):\n",
        "        return iter(self.data)"
      ],
      "metadata": {
        "id": "mfoAcGPFZU8p"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e4NQ_-5waHVc"
      },
      "source": [
        "Now we can instantiate the dataset and create a dataloader:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AbLUcojlieXo"
      },
      "outputs": [],
      "source": [
        "ds = SimpleIterableDataset(X, Y)\n",
        "dl = torch.utils.data.DataLoader(ds, batch_size = 256)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pKQr7sXeaHVc"
      },
      "source": [
        "Now let's do the actual training. We will use the `SGD` optimizer with a fairly high learning rate (0.1). You can also try other optimizers, such as `Adam`. We will train for 10 epochs to begin with, and you can re-run this cell if you want an even lower loss."
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def train_epoch(net, dataloader, lr = 0.01, optimizer = None, loss_fn = torch.nn.CrossEntropyLoss(), epochs = 1, report_freq = 1):\n",
        "    # Fall back to Adam when no optimizer is supplied\n",
        "    optimizer = optimizer or torch.optim.Adam(net.parameters(), lr = lr)\n",
        "    loss_fn = loss_fn.to(device)\n",
        "    net.train()\n",
        "\n",
        "    for i in range(epochs):\n",
        "        total_loss, j = 0.0, 0\n",
        "        for labels, features in dataloader:\n",
        "            optimizer.zero_grad()\n",
        "            features, labels = features.to(device), labels.to(device)\n",
        "            out = net(features)\n",
        "            loss = loss_fn(out, labels)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            # Accumulate a plain float so the autograd graph is not kept alive\n",
        "            total_loss += loss.item()\n",
        "            j += 1\n",
        "        if i % report_freq == 0:\n",
        "            print(f\"Epoch: {i+1}: loss={total_loss/j}\")\n",
        "\n",
        "    return total_loss/j"
      ],
      "metadata": {
        "id": "HeeCYKr_KF1w"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_epoch(net = model, dataloader = dl, optimizer = torch.optim.SGD(model.parameters(), lr = 0.1), loss_fn = torch.nn.CrossEntropyLoss(), epochs = 10)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "KVgwGtDHgDlT",
        "outputId": "2447833f-f0e3-4566-c33d-addbfe2f451d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch: 1: loss=5.664632366860172\n",
            "Epoch: 2: loss=5.632101973960962\n",
            "Epoch: 3: loss=5.610399051405015\n",
            "Epoch: 4: loss=5.594621561080262\n",
            "Epoch: 5: loss=5.582538017415446\n",
            "Epoch: 6: loss=5.572900234519603\n",
            "Epoch: 7: loss=5.564951676341915\n",
            "Epoch: 8: loss=5.558288112064614\n",
            "Epoch: 9: loss=5.552576955031129\n",
            "Epoch: 10: loss=5.547634165194347\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "5.547634165194347"
            ]
          },
          "metadata": {},
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W8u2qXZmaHVd"
      },
      "source": [
        "## Trying out Word2Vec\n",
        "\n",
        "To use Word2Vec, let's extract vectors corresponding to all words in our vocabulary:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r8TatcXjkU_t"
      },
      "outputs": [],
      "source": [
        "vectors = torch.stack([embedder(torch.tensor(vocab[s])) for s in vocab.itos], 0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3OcX21UOaHVd"
      },
      "source": [
        "Let's see, for example, how the word **Paris** is encoded into a vector:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bz6tAeLzieXp",
        "outputId": "5b20850e-4342-45e9-f840-cfac2b4d61d8"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "tensor([-0.0915,  2.1224, -0.0281, -0.6819,  1.1219,  0.6458, -1.3704, -1.3314,\n",
            "        -1.1437,  0.4496,  0.2301, -0.3515, -0.8485,  1.0481,  0.4386, -0.8949,\n",
            "         0.5644,  1.0939, -2.5096,  3.2949, -0.2601, -0.8640,  0.1421, -0.0804,\n",
            "        -0.5083, -1.0560,  0.9753, -0.5949, -1.6046,  0.5774],\n",
            "       grad_fn=<EmbeddingBackward>)\n"
          ]
        }
      ],
      "source": [
        "paris_vec = embedder(torch.tensor(vocab['paris']))\n",
        "print(paris_vec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pHTJlaeYaHVd"
      },
      "source": [
        "It is interesting to use Word2Vec to look for synonyms. The following function returns the `n` words closest to a given input. To find them, we compute the norm $\\lVert w_i - v \\rVert$, where $v$ is the vector corresponding to our input word and $w_i$ is the embedding of the $i$-th word in the vocabulary. We then use `argsort` to obtain the indices that sort these distances in ascending order, and take the first `n` of them, which give the positions of the closest words in the vocabulary."
      ]
    },
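    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Written out, and assuming the Euclidean norm (which is what `np.linalg.norm` computes along an axis by default), the distance between the $i$-th vocabulary vector and the input vector is as follows, where $D$ is the embedding size (30 in this notebook) and $d_i$ is a symbol introduced here for illustration:\n",
        "\n",
        "$$d_i = \\lVert w_i - v \\rVert_2 = \\sqrt{\\sum_{j=1}^{D} (w_{i,j} - v_j)^2}$$\n",
        "\n",
        "Since the input word is itself in the vocabulary, its distance to $v$ is $0$, so it is expected to appear first in the returned list."
      ]
    },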
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NlZyi-_olFar",
        "outputId": "b5dbb163-88c4-4d5a-eaf2-6751f700e98c"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['microsoft', 'quoted', 'lp', 'rate', 'top']"
            ]
          },
          "metadata": {},
          "execution_count": 56
        }
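      ],
      "source": [
        "def close_words(x, n = 5):\n",
        "    # Embedding of the query word\n",
        "    vec = embedder(torch.tensor(vocab[x]))\n",
        "    # Indices of the n vocabulary vectors with the smallest Euclidean distance to vec\n",
        "    closest = np.linalg.norm(vectors.detach().numpy() - vec.detach().numpy(), axis = 1).argsort()[:n]\n",
        "    return [ vocab.itos[i] for i in closest ]\n",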
        "\n",
        "close_words('microsoft')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-dQq7xeAln0U",
        "outputId": "66f768c3-c248-4bfd-ce4f-c8ffc6d0dd0d"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['basketball', 'lot', 'sinai', 'states', 'healthdaynews']"
            ]
          },
          "metadata": {},
          "execution_count": 51
        }
      ],
      "source": [
        "close_words('basketball')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fJXqK26b29sa",
        "outputId": "78f0baba-ffd0-485a-dd87-0a12bedfd7fa"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['funds', 'travel', 'sydney', 'japan', 'business']"
            ]
          },
          "metadata": {},
          "execution_count": 77
        }
      ],
      "source": [
        "close_words('funds')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "My0VeTDd3Ji8"
      },
      "source": [
        "## Takeaway\n",
        "\n",
        "Using clever techniques such as CBoW, we can train a Word2Vec model. You may also try to train a skip-gram model, which learns to predict a neighboring word given the central one, and see how well it performs."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "CBoW-PyTorch.ipynb",
      "provenance": []
    },
    "interpreter": {
      "hash": "16af2a8bbb083ea23e5e41c7f5787656b2ce26968575d8763f2c4b17f9cd711f"
    },
    "kernelspec": {
      "display_name": "Python 3.8.12 ('py38')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.12"
    },
    "orig_nbformat": 4,
    "gpuClass": "standard"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}